//mmatch.ado
//"Mata Match"
//This ado file implements nearest neighbor matching and bias corrected matching
//On the Negative Side: no standard errors, only effect of treatment on the treated (att)
//On the Plus Side: This code runs in a fraction of the time required by the original Stata
//code from Abadie, Drukker, Herr, and Imbens 
//
//Example syntaxes:
//pair matching on covariates X1 and X2
//mmatch Y T X1 X2, m(1) tc(att)
//kth nearest neighbor matching on X1 and X2 with 4 matches
//mmatch Y T X1 X2, m(4) tc(att)
//bias adjusted matching on X1 and X2 with bias adjustment for Z1 and Z2
//mmatch Y T X1 X2, biasadj(Z1 Z2) m(4) tc(att)
//
//JRM 4.10.2012
//5.1.2012: Update to accommodate collinearity of regressors


// mmatch: nearest-neighbor matching estimator of the treatment effect on the
// treated (ATT), optionally bias corrected.
//
// Syntax:  mmatch Y T X1 [X2 ...] [if] [in], tc(att) [m(#) biasadj(varlist) kweight(name)]
//   varlist   first variable is the outcome, second the 0/1 treatment indicator,
//             the remaining ones are the matching covariates
//   m(#)      number of matches per treated observation (default 1)
//   biasadj() variables passed to the Mata routine for the bias adjustment
//   tc(att)   required; ATT is the only estimand implemented
//   kweight() name of a variable to (re)create holding the match weights;
//             an existing variable with that name is dropped
// Returns:  r(NN), and r(BCM) when biasadj() is specified.
program mmatch, rclass
  version 11.1
  syntax varlist(numeric min=3) [if] [in], ///
    [m(integer 1) BIASadj(string) tc(string) Kweight(string)]

  // Fail with a nonzero return code instead of silently returning success.
  if ("`tc'"~="att") {
    di in red "Estimation of treatment effect on the treated is the only supported option."
    di in red "Please add option tc(att)."
    exit 198
  }
  if (`m'<1) {
    di in red "Option m() must be a positive integer."
    exit 198
  }

  marksample touse           // 1 for observations satisfying [if] [in] with no missing varlist values
  tokenize `varlist'         // `2' is the treatment indicator

  // The Mata routine selects rows with T and 1-T, so anything but 0/1 would
  // put an observation in both groups.
  capture assert inlist(`2',0,1) if `touse'
  if (_rc) {
    di in red "Treatment variable `2' must be coded 0/1."
    exit 450
  }

  if ("`kweight'"=="") {
    tempvar kweight
  }
  cap drop `kweight'
  gen double `kweight'=.

  tempvar T touse2
  gen byte `T'=`2'
  // Strict 0/1 marker for in-sample controls. The previous product
  // touse*(1-T) evaluated to missing (i.e. nonzero) wherever T was missing,
  // which made st_view() pick up out-of-sample rows for the weight view.
  gen byte `touse2'=(`touse'==1 & `T'==0)

  // Both groups must be non-empty for matching to be defined.
  quietly count if `touse2'
  if (r(N)==0) {
    di in red "No control observations in the estimation sample."
    exit 2001
  }
  quietly count if `touse' & `T'==1
  if (r(N)==0) {
    di in red "No treated observations in the estimation sample."
    exit 2001
  }

  // The last argument tells the Mata routine whether to compute the
  // bias-corrected estimate in addition to the plain matching estimate.
  local adjust=("`biasadj'"~="")
  mata: mmatch("`varlist'","`touse'","`touse2'","`biasadj'","`kweight'",`m',`adjust')
  return scalar NN=`NN'
  if (`adjust') {
    return scalar BCM=`BCM'
  }

  // Treated observations carry weight 1; only touch the estimation sample.
  replace `kweight'=1 if `T'==1 & `touse'
  ereturn clear              // kept from the original: wipes any e() results left by earlier commands
end

version 11.1
mata:

  // Return the transpose of the real matrix A.
  // For real (non-complex) arguments the ' operator performs a plain
  // transposition, so no conjugation is involved.
  real matrix function transpose(real matrix A) {
    return(A')
  }
  
  // mmatch(): nearest-neighbor matching estimate of the ATT, with optional
  // regression-based bias correction.
  //
  // Arguments
  //   varlist1  names of: outcome, treatment indicator, then matching covariates
  //   touse     name of the variable selecting the estimation sample
  //   touse2    name of the variable selecting the controls inside that sample
  //   varlist2  names of the bias-adjustment variables (read only if adjust>0)
  //   KMvar     name of the variable that receives, for each control, the
  //             (tie-weighted) number of times it is used as a match
  //   M         requested number of matches per treated observation
  //   adjust    >0 to compute the bias-corrected estimate as well
  //
  // Results are handed back to the caller as Stata locals: NN always, BCM when
  // adjust>0. KMvar is overwritten for the rows selected by touse2.
  void mmatch(string scalar varlist1, string scalar touse, string scalar touse2, ///
              string scalar varlist2, string scalar KMvar, ///
              real scalar M, real scalar adjust) {
    real matrix A, B, X, XY, XY1, XY0, X0, X1, D, matches, Z, Z1, Z0, Z0hat, w
    real colvector Y, T, Y0, Y1, Y0hat, jofi, n_jofi, KM, b
    real rowvector muZ0
    real scalar K, n1, n0, h, m, k, i, NN, BMt, BCM, muY0
    st_view(A=., ., varlist1, touse)
    st_view(KM=., ., KMvar, touse2)
    K=cols(A)-2
    Y=A[.,1]
    T=A[.,2]
    X=A[.,2+1::2+K]                //A is not needed after this, so it is recycled below
    n1=sum(T)
    n0=rows(T)-n1
    //Matching is undefined without at least one unit in each group; the
    //match search below would also never terminate with n0==0.
    if (n1<1 | n0<1) {
      errprintf("mmatch: need at least one treated and one control observation\n")
      exit(error(2001))
    }
    //KM must line up row-for-row with the controls extracted below.
    if (rows(KM)!=n0) {
      errprintf("mmatch: control marker selects %g observations, expected %g\n", rows(KM), n0)
      exit(498)
    }
    //X in standardized units
    X=(X:-mean(X)):/transpose(sqrt(diagonal(variance(X))))
    XY=X,Y //stack so that the row order of X0,Y0 and of X1,Y1 stays consistent
    XY1=select(XY,T)
    XY0=select(XY,1:-T)
    X0=XY0[.,1::K]  //the ith row of X0 corresponds to the ith row of Y0
    Y0=XY0[.,K+1]
    X1=XY1[.,1::K]  //and the ith row of X1 corresponds to the ith row of Y1
    Y1=XY1[.,K+1]
    if (adjust>0) {
      st_view(Z=., ., varlist2, touse) //variables used for bias adjustment
      Z1=select(Z,T)
      Z0=select(Z,1:-T)
    }
    //Next form the n1-by-n0 matrix of distances between treated and control
    //covariates using the Euclidean metric on the standardized covariates
    D=(X1[.,1] :- J(n1,1,transpose(X0[.,1]))):^2 //n1xn0 matrix of squared differences
    for (k=2; k<=K; k++) {
      D=D+(X1[.,k] :- J(n1,1,transpose(X0[.,k]))):^2 //incrementing the metric
    }
    h=1e-7
    D=round(sqrt(D)/h)*h //a bit of discretization is needed to match
                         //the Stata code from Abadie, Drukker, Herr, and Imbens
    //Loop over T=1 obs. For each i, find the closest matches.
    //Ties are kept, so a treated unit may end up with more than M matches.
    Y0hat=J(n1,1,.)               //Y0hat is the imputed counterfactual outcome
    Z0hat=J(n1,cols(Z),.)         //imputed covariates used by the bias correction
    matches=J(n1,n0,.)            //row i holds the indices of the controls matched to i
    for (i=1; i<=n1; i++) {
      m=1
      minindex(D[i,.],m,jofi=.,w=.) //jofi holds the column(s) of the closest match;
                                    //with ties it has several elements
      //Widen the search one distinct distance at a time until at least M
      //controls are found. The m<n0 bound guarantees termination when fewer
      //than M controls are available (all usable controls are then taken).
      while (rows(jofi)<M & m<n0) {
        minindex(D[i,.],++m,jofi=.,w=.)
      }
      //No usable distance in this row (e.g. all missing): leave the imputed
      //values missing rather than indexing with an empty match list.
      if (rows(jofi)==0) {
        continue
      }
      matches[i,1::rows(jofi)]=transpose(jofi)
      Y0hat[i]=mean(Y0[jofi])
      if (adjust>0) {
        Z0hat[i,.]=mean(Z0[jofi,.])
      }
    }
    NN=mean(Y1-Y0hat)
    //Put back into Stata. %21.0g keeps the full double precision; the former
    //%10.0g cut the result to roughly 8 significant digits.
    st_local("NN",strofreal(NN,"%21.0g"))
    //form KM(j), the number of times control unit j is used as a match
    KM[.,.]=J(n0,1,0) //reset to zero
    n_jofi=rownonmissing(matches) //count of matches used for treated unit i
    for (i=1; i<=n1; i++) {
      for (m=1; m<=n_jofi[i]; m++) {
        jofi=matches[i,m]   //this is the mth match for treated unit i
        KM[jofi]=KM[jofi]+1/n_jofi[i] //each of i's matches gets an equal share
      }
    }
    //If bias adjustment is requested, perform that as well
    if (adjust>0) {
      //Weighted regression of Y0 on Z0 among the controls, weights KM,
      //expressed in deviations from the weighted means
      muZ0=mean(Z0,KM)
      muY0=mean(Y0,KM)
      A=quadcrossdev(Z0,muZ0,KM,Z0,muZ0)
      B=quadcrossdev(Z0,muZ0,KM,Y0,muY0)
      b=qrsolve(A,B)                     //the constant is not needed;
                                         //qrsolve is used to cope with collinear regressors
      BMt=mean((Z1-Z0hat)*b)             //this is the bias adjustment
      BCM=NN-BMt                         //this is bias-corrected matching
      //put back into Stata, again at full precision
      st_local("BCM",strofreal(BCM,"%21.0g"))
    }
  }

end

